/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 * ===================================================================================================================*/

#include <gtest/gtest.h>
#include <climits>

#define protected public
#define private public

#include "inc/external/register/register_pass.h"
#include "graph/debug/ge_log.h"
#include "register/custom_pass_helper.h"

#define protected public
#define private public

namespace ge {
class UtestRegisterPass : public testing::Test { 
 protected:
  void SetUp() {}
  void TearDown() {}
};

// With impl_ forcibly cleared, GetPriority() is expected to report INT_MAX.
TEST_F(UtestRegisterPass, GetPriorityTest) {
  ge::PassRegistrationData registration;
  registration.impl_ = nullptr;  // simulate a registration object without backing state

  const int32_t observed_priority = registration.GetPriority();
  EXPECT_EQ(observed_priority, INT_MAX);
}

// A named registration that never had Priority() called is expected to report INT_MAX.
TEST_F(UtestRegisterPass, GetPriorityImplIsNotNull) {
  ge::PassRegistrationData named_registration("register_pass");

  const int32_t default_priority = named_registration.GetPriority();
  EXPECT_EQ(default_priority, INT_MAX);
}

// GetPassName() should echo the constructor-supplied name while impl_ is present,
// and is expected to yield an empty string once impl_ has been cleared.
TEST_F(UtestRegisterPass, GetPassNameTest) {
  ge::PassRegistrationData registration("registry");

  const std::string name_with_impl = registration.GetPassName();
  EXPECT_EQ(name_with_impl, "registry");

  registration.impl_ = nullptr;
  const std::string name_without_impl = registration.GetPassName();
  EXPECT_EQ(name_without_impl, "");
}

// Priority() with a negative value is expected to leave the priority at INT_MAX,
// while a non-negative value is expected to be stored and read back unchanged.
TEST_F(UtestRegisterPass, PriorityTest) {
  ge::PassRegistrationData registration("registry");

  // Negative input: the test expects the default INT_MAX to remain in effect.
  int32_t negative_priority = -1;
  registration.Priority(negative_priority);
  EXPECT_EQ(registration.GetPriority(), INT_MAX);

  // Non-negative input: the test expects the value to be reported back as-is.
  int32_t accepted_priority = 2;
  registration.Priority(accepted_priority);
  EXPECT_EQ(registration.GetPriority(), 2);
}

// Setting an empty custom pass function must read back as empty, both with a live impl_
// and after impl_ has been cleared.
TEST_F(UtestRegisterPass, CustomPassFnTest) {
  ge::PassRegistrationData registration("registry");

  // Case 1: impl_ present, empty callback registered.
  CustomPassFunc empty_callback = nullptr;
  registration.CustomPassFn(empty_callback);
  auto stored_callback = registration.GetCustomPassFn();
  EXPECT_EQ(stored_callback, nullptr);

  // Case 2: impl_ cleared; a default-constructed (empty) callback is registered again.
  empty_callback = CustomPassFunc{};
  registration.impl_ = nullptr;
  registration.CustomPassFn(empty_callback);
  stored_callback = registration.GetCustomPassFn();
  EXPECT_EQ(stored_callback, nullptr);
}

// Runs CustomPassHelper::Run on a locally constructed helper twice: first with its registration
// set untouched, then after manually inserting a registration that has no custom pass function.
// Both runs are expected to return SUCCESS.
TEST_F(UtestRegisterPass, CustomPassHelperRunTest) {
  PassRegistrationData pass_data("registry");
  // Constructed for its side effect; presumably registers pass_data with a shared helper
  // instance — TODO confirm. cust_helper below is a separately constructed local object.
  ge::PassReceiver pass_receiver(pass_data);
  CustomPassHelper cust_helper;
  auto graph = std::make_shared<Graph>("test");
  // Keep Run()'s result as Status: narrowing it to bool collapses every non-SUCCESS code to
  // `true`, hiding the actual failure code in the assertion message.
  Status ret = cust_helper.Run(graph);
  EXPECT_EQ(ret, SUCCESS);

  // This registration never has CustomPassFn() called on it.
  PassRegistrationData pass_data2("registry2");
  cust_helper.registration_datas_.insert(pass_data2);
  auto graph2 = std::make_shared<Graph>("test2");
  ret = cust_helper.Run(graph2);
  EXPECT_EQ(ret, SUCCESS);
}
} // namespace ge
